Back to Article
Article Notebook
Download Source

Diptera wing classification using Topological Data Analysis

Authors
Affiliation

Guilherme Vituri F. Pinto

Universidade Estadual Paulista

Sergio Ura

Northon

Published

February 25, 2026

Abstract

We apply tools from Topological Data Analysis (TDA) to classify Diptera families based on wing venation patterns. Using two complementary filtration strategies — Vietoris-Rips on point clouds and radial filtrations on wing images — we extract both H0 and H1 topological features via extended summary statistics and compare classifiers via leave-one-out cross-validation. We focus on interpretable models (LDA, Decision Trees) to identify explainable topological criteria that distinguish families.

Keywords

Topological Data Analysis, Persistent homology, Diptera classification, Wing venation

In [1]:
using TDAfly, TDAfly.Preprocessing, TDAfly.TDA, TDAfly.Analysis
using Images: mosaicview, Gray
using Plots: plot, display, heatmap, scatter, bar
using StatsPlots: boxplot
using PersistenceDiagrams
using PersistenceDiagrams: BettiCurve, Landscape, PersistenceImage
using DataFrames
using Distances: euclidean
using LIBSVM
using StatsBase: mean

1 Introduction

The order Diptera (true flies) comprises over 150,000 described species across more than 150 families. Wing venation patterns are a classical diagnostic character in Diptera systematics: the arrangement, branching and connectivity of veins varies markedly across families and provides a natural morphological signature.

In this work, we apply Topological Data Analysis (TDA) to the problem of classifying Diptera families from wing images. TDA provides a framework for extracting shape descriptors that are robust to continuous deformations — exactly the kind of invariance desirable when comparing biological structures that vary in scale, orientation and minor deformations across individuals.

We employ two complementary filtration strategies:

  1. Vietoris-Rips filtration on point-cloud samples of wing silhouettes — captures global loop structure
  2. Radial filtration from the wing centroid to the periphery — captures how vein topology is organized from center outward

For each filtration, we compute both H0 (connected components / vein branching) and H1 (loops / enclosed cells) persistence, then extract extended summary statistics (19 interpretable features per diagram) and classify using simple, explainable models (LDA, Decision Trees, Random Forests). The goal is to find interpretable topological criteria for family identification.

Why only two filtrations?

We initially tested five filtration strategies including directional height filtrations (8 directions), Euclidean Distance Transform (EDT), and grayscale cubical filtrations. However: (a) directional filtrations are noise-sensitive — in images with isolated pixels and incomplete vein segmentation, each sweep direction creates spurious topological features; (b) EDT produces trivial persistence on binarized images where veins are ~1 pixel wide; (c) cubical (grayscale) filtrations are meaningless on already-binarized black-and-white images. See NOTES.md for details on discarded methods.

2 Methods

2.1 Data loading and preprocessing

All images are in the images/processed directory. For each image, we load it, apply a Gaussian blur (to close small gaps in the wing membrane and keep it connected), crop to the bounding box, and resize to 150 pixels of height.

In [2]:
all_paths = readdir("images/processed", join = true)
all_filenames = basename.(all_paths) .|> (x -> replace(x, ".png" => ""))

function extract_family(name)
    family_raw = lowercase(split(name, r"[\s\-]")[1])
    if family_raw in ("bibionidae", "biobionidae")
        return "Bibionidae"
    elseif family_raw in ("sciaridae", "scaridae")
        return "Sciaridae"
    elseif family_raw == "simulidae"
        return "Simuliidae"
    else
        return titlecase(family_raw)
    end
end

function canonical_id(name)
    family = extract_family(name)
    parts = split(name, r"[\s\-]")
    number = parts[end]
    "$(family)-$(number)"
end

# Deduplicate (space vs hyphen variants of the same file)
seen = Set{String}()
keep_idx = Int[]
for (i, fname) in enumerate(all_filenames)
    cid = canonical_id(fname)
    if !(cid in seen)
        push!(seen, cid)
        push!(keep_idx, i)
    end
end

paths = all_paths[keep_idx]
species = all_filenames[keep_idx]
families = extract_family.(species)

individuals = map(species) do specie
    parts = split(specie, r"[\s\-]")
    string(extract_family(specie)[1]) * "-" * parts[end]
end

println("Total images after deduplication: $(length(paths))")
println("Families: ", sort(unique(families)))
println("\nSamples per family:")
for f in sort(unique(families))
    println("  $(f): $(count(==(f), families))")
end
Total images after deduplication: 70
Families: ["Asilidae", "Bibionidae", "Ceratopogonidae", "Chironomidae", "Rhagionidae", "Sciaridae", "Simuliidae", "Tabanidae", "Tipulidae"]

Samples per family:
  Asilidae: 8
  Bibionidae: 6
  Ceratopogonidae: 8
  Chironomidae: 8
  Rhagionidae: 4
  Sciaridae: 6
  Simuliidae: 7
  Tabanidae: 11
  Tipulidae: 12

2.1.1 Excluding small families

Families with fewer than 3 samples (e.g. Pelecorhynchidae with \(n=2\)) can distort cross-validation results—a single misclassification changes accuracy by 50%. We provide a filtered version and run the analysis both ways.

In [3]:
MIN_FAMILY_SIZE = 3
family_counts = Dict(f => count(==(f), families) for f in unique(families))
small_families = [f for (f, c) in family_counts if c < MIN_FAMILY_SIZE]

if !isempty(small_families)
    println("Families with < $MIN_FAMILY_SIZE samples (excluded from filtered analysis):")
    for f in sort(small_families)
        println("  $(f): $(family_counts[f]) samples")
    end
end

# Build filtered indices
keep_filtered = [i for i in eachindex(families) if family_counts[families[i]] >= MIN_FAMILY_SIZE]
paths_filtered = paths[keep_filtered]
species_filtered = species[keep_filtered]
families_filtered = families[keep_filtered]
individuals_filtered = individuals[keep_filtered]

println("\nFiltered dataset: $(length(keep_filtered)) samples, $(length(unique(families_filtered))) families")

Filtered dataset: 70 samples, 9 families
In [4]:
wings = load_wing.(paths, blur = 1.8)
Xs = map(wings) do w
    image_to_r2(w; threshold=0.08, ensure_connected = true, connectivity = 8)
end;
In [5]:
wings[1]
In [6]:
scatter(Xs[5] .|> first, Xs[5] .|> last)
In [7]:
mosaicview(wings, ncol = 6, fillvalue = 1)

2.2 Example: forcing connectivity on 5 wings

The chunk below selects 5 wings (prioritizing those with the largest number of disconnected components before correction), then compares the binary pixel set before and after connect_pixel_components.

In [8]:
threshold_conn = 0.1
conn = 8

component_count_before = map(wings) do w
    ids0 = findall_ids(>(threshold_conn), image_to_array(w))
    length(pixel_components(ids0; connectivity = conn))
end

demo_idx = sortperm(component_count_before, rev = true)[1:min(5, length(wings))]

function ids_to_mask(ids)
    isempty(ids) && return zeros(Float32, 1, 1)
    xs = first.(ids)
    ys = last.(ids)
    M = zeros(Float32, maximum(xs), maximum(ys))
    for p in ids
        M[p[1], p[2]] = 1f0
    end
    M
end

demo_connectivity_df = DataFrame(
    sample = String[],
    n_components_before = Int[],
    n_components_after = Int[],
    n_pixels_before = Int[],
    n_pixels_after = Int[],
)

panel_plots = Any[]
for idx in demo_idx
    ids_before = findall_ids(>(threshold_conn), image_to_array(wings[idx]))
    ids_after = connect_pixel_components(ids_before; connectivity = conn)

    n_before = length(pixel_components(ids_before; connectivity = conn))
    n_after = length(pixel_components(ids_after; connectivity = conn))

    push!(demo_connectivity_df, (
        species[idx],
        n_before,
        n_after,
        length(ids_before),
        length(ids_after),
    ))

    M_before = ids_to_mask(ids_before)
    M_after = ids_to_mask(ids_after)

    p_before = heatmap(
        M_before[end:-1:1, :],
        color = :grays,
        colorbar = false,
        legend = false,
        aspect_ratio = :equal,
        xticks = false,
        yticks = false,
        title = "Before: $(species[idx])\ncomponents = $(n_before)",
    )

    p_after = heatmap(
        M_after[end:-1:1, :],
        color = :grays,
        colorbar = false,
        legend = false,
        aspect_ratio = :equal,
        xticks = false,
        yticks = false,
        title = "After: $(species[idx])\ncomponents = $(n_after)",
    )

    push!(panel_plots, p_before)
    push!(panel_plots, p_after)
end

plot(panel_plots..., layout = (length(demo_idx), 2), size = (900, 260 * length(demo_idx)))
In [9]:
demo_connectivity_df
5×5 DataFrame
Row sample n_components_before n_components_after n_pixels_before n_pixels_after
String Int64 Int64 Int64 Int64
1 simulidae 27 101 1 9233 9568
2 biobionidae 9 88 1 9425 9605
3 simulidae 26 80 1 11556 11778
4 chironomidae 19 75 1 11504 11763
5 simulidae 24 71 1 9019 9189

3 Topological feature extraction

We compute persistent homology using two filtration strategies. For the Vietoris-Rips filtration on connected point clouds, H0 is uninformative (single infinite bar), so we use only H1. For the radial filtration (computed via sublevel-set persistence on the pixel grid), H0 is highly informative — it captures when disconnected vein segments merge as the filtration parameter grows, directly encoding vein count and branching patterns. We therefore compute both H0 and H1 for the radial filtration.

What is persistent homology?

Persistent homology is the main tool of TDA. Given a shape or dataset, it tracks how topological features — connected components (dimension 0), loops (dimension 1), voids (dimension 2), etc. — appear and disappear as we “grow” the shape through a filtration parameter. Each feature has a birth time (when it appears) and a death time (when it gets filled in). The collection of all (birth, death) pairs is called a persistence diagram. Features with long lifetimes (high persistence = death \(-\) birth) represent genuine topological structure, while short-lived features are typically noise.

3.1 Strategy 1: Vietoris-Rips filtration on point clouds

Vietoris-Rips filtration

Given a set of points in \(\mathbb{R}^n\), the Vietoris-Rips complex at scale \(\varepsilon\) connects any subset of points that are pairwise within distance \(\varepsilon\). As \(\varepsilon\) increases from 0, we obtain a nested sequence of simplicial complexes — the Rips filtration. This is the most common filtration in TDA for point-cloud data. It is computationally expensive (since it must consider all pairwise distances), which is why we subsample the point clouds.

We sample 750 points from each wing silhouette using farthest-point sampling (which ensures good coverage of the shape), then compute 1-dimensional Rips persistence:

In [10]:
samples = Vector{Any}(undef, length(Xs))
Threads.@threads for i in eachindex(Xs)
    samples[i] = farthest_points_sample(Xs[i], 750)
end
In [11]:
pds_rips = @showprogress map(samples) do s
    rips_pd_1d(s, cutoff = 5, threshold = 200)
end;
Progress:   3%|█▏                                       |  ETA: 0:01:38

Progress:  20%|████████▎                                |  ETA: 0:00:18

Progress:  36%|██████████████▋                          |  ETA: 0:00:10

Progress:  57%|███████████████████████▍                 |  ETA: 0:00:06

Progress:  71%|█████████████████████████████▎           |  ETA: 0:00:04

Progress:  87%|███████████████████████████████████▊     |  ETA: 0:00:02

Progress: 100%|█████████████████████████████████████████| Time: 0:00:12
In [12]:
wing_arrays = [convert(Array{Float64}, w) for w in wings]
70-element Vector{Matrix{Float64}}:
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 ⋮
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]
 [1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; … ; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813; 1.0000000074505813 1.0000000074505813 … 1.0000000074505813 1.0000000074505813]

3.2 Strategy 2: Radial filtration

Radial filtration

The radial filtration assigns each foreground pixel a value equal to its distance from the centroid of the wing. Sublevel-set persistence on this function captures how topological features (loops in the venation) are distributed from the center of the wing outward. This is complementary to the Rips filtration, which captures global loop structure without spatial information.

In [13]:
pds_radial = @showprogress "radial_pd_1d" map(wing_arrays) do A
    radial_pd_1d(A)
end;
radial_pd_1d   3%|█▏                                     |  ETA: 0:00:49

radial_pd_1d  59%|██████████████████████▉                |  ETA: 0:00:02

radial_pd_1d 100%|███████████████████████████████████████| Time: 0:00:02

We also compute H0 persistence for the radial filtration, capturing how disconnected vein segments merge as the radial sweep grows outward:

In [14]:
pds_radial_h0 = @showprogress "radial_pd_0d" map(wing_arrays) do A
    radial_pd_0d(A)
end;
radial_pd_0d   6%|██▎                                    |  ETA: 0:00:02

radial_pd_0d  67%|██████████████████████████▏            |  ETA: 0:00:01

radial_pd_0d 100%|███████████████████████████████████████| Time: 0:00:01

3.3 Visualizing the radial filtration

The radial filtration assigns each foreground pixel a value proportional to its distance from the wing centroid. Below we visualize the radial filtration arrays and the resulting persistence diagrams for one wing per family:

In [15]:
example_indices = [findfirst(==(f), families) for f in sort(unique(families))]

for i in example_indices[1:min(5, length(example_indices))]
    F_rad = radial_filtration(wing_arrays[i])

    # Only show foreground in the heatmap
    F_display = copy(F_rad)
    F_display[F_display .> 1.5] .= NaN  # mask background

    p1 = heatmap(F_display[end:-1:1, :],
                 color = :viridis, colorbar = true,
                 title = "Radial filtration",
                 aspect_ratio = :equal, xticks = false, yticks = false)

    p2 = heatmap(wing_arrays[i][end:-1:1, :],
                 color = :grays, colorbar = false,
                 title = "Wing image",
                 aspect_ratio = :equal, xticks = false, yticks = false)

    pers_rad_h1 = persistence.(pds_radial[i])
    p3 = isempty(pers_rad_h1) ? plot(title = "Radial H₁ (empty)") :
         bar(sort(pers_rad_h1, rev = true), title = "Radial H₁ (loops)",
             legend = false, ylabel = "persistence")

    pers_rad_h0 = [persistence(x) for x in pds_radial_h0[i] if isfinite(persistence(x))]
    p4 = isempty(pers_rad_h0) ? plot(title = "Radial H₀ (empty)") :
         bar(sort(pers_rad_h0, rev = true), title = "Radial H₀ (components)",
             legend = false, ylabel = "persistence")

    p = plot(p1, p2, p3, p4, layout = (2, 2), size = (900, 700),
             plot_title = "$(families[i])$(individuals[i])")
    display(p)
end;

3.4 Examples: persistence diagrams from each strategy

Below we show persistence diagrams from both Rips and radial filtrations for one specimen per family:

In [16]:
for i in example_indices
    pers_rips = persistence.(pds_rips[i])
    pers_rad = persistence.(pds_radial[i])
    pers_rad_h0 = [persistence(x) for x in pds_radial_h0[i] if isfinite(persistence(x))]

    p1 = isempty(pers_rips) ? plot(title = "Rips H₁ (empty)") :
         bar(sort(pers_rips, rev = true), title = "Rips H₁", legend = false, ylabel = "persistence")
    p2 = isempty(pers_rad) ? plot(title = "Radial H₁ (empty)") :
         bar(sort(pers_rad, rev = true), title = "Radial H₁", legend = false, ylabel = "persistence")
    p3 = isempty(pers_rad_h0) ? plot(title = "Radial H₀ (empty)") :
         bar(sort(pers_rad_h0, rev = true), title = "Radial H₀", legend = false, ylabel = "persistence")
    p4 = scatter(last.(samples[i]), first.(samples[i]),
                 aspect_ratio = :equal, markersize = 1, legend = false, title = "Point cloud")

    p = plot(p1, p2, p3, p4, layout = (2, 2), size = (900, 650),
             plot_title = "$(families[i]) ($(individuals[i]))")
    display(p)
end;

3.5 Extended summary statistics

We extract 19 summary statistics from each persistence diagram using pd_statistics_extended:

  • Count of intervals, max/total/total² persistence
  • Quantiles (10th, 25th, 50th, 75th, 90th)
  • Entropy, std of persistence
  • Skewness, kurtosis of persistence distribution
  • Median birth, median death, std birth, std death
  • Mean midlife = mean of (birth + death)/2
  • Persistence range = max - min persistence
In [17]:
using DecisionTree
using Random: MersenneTwister

stat_names_ext = [
    "count", "max_pers", "total_pers", "total_pers2",
    "q10", "q25", "median", "q75", "q90",
    "entropy", "std_pers",
    "skewness", "kurtosis",
    "median_birth", "median_death", "std_birth", "std_death",
    "mean_midlife", "pers_range"
]

stats_rips = collect(hcat([pd_statistics_extended(pd) for pd in pds_rips]...)')
stats_radial = collect(hcat([pd_statistics_extended(pd) for pd in pds_radial]...)')
stats_radial_h0 = collect(hcat([pd_statistics_extended(pd) for pd in pds_radial_h0]...)')

println("Statistics per diagram: $(length(stat_names_ext)) features")
println("  Rips H1: $(size(stats_rips))")
println("  Radial H1: $(size(stats_radial))")
println("  Radial H0: $(size(stats_radial_h0))")
Statistics per diagram: 19 features
  Rips H1: (70, 19)
  Radial H1: (70, 19)
  Radial H0: (70, 19)

3.5.1 Statistics comparison by family

In [18]:
stats_df = DataFrame(
    sample = individuals,
    family = families,
    n_intervals_rips = stats_rips[:, 1],
    max_pers_rips = stats_rips[:, 2],
    entropy_rips = stats_rips[:, 10],
    n_intervals_rad = stats_radial[:, 1],
    max_pers_rad = stats_radial[:, 2],
    entropy_rad = stats_radial[:, 10],
    skewness_rips = stats_rips[:, 12],
    median_death_rips = stats_rips[:, 15],
)

p1 = boxplot(stats_df.family, stats_df.n_intervals_rips,
             title = "Rips: # H₁ intervals", legend = false, ylabel = "count", xrotation = 45)
p2 = boxplot(stats_df.family, stats_df.max_pers_rips,
             title = "Rips: max persistence", legend = false, ylabel = "persistence", xrotation = 45)
p3 = boxplot(stats_df.family, stats_df.n_intervals_rad,
             title = "Radial: # H₁ intervals", legend = false, ylabel = "count", xrotation = 45)
p4 = boxplot(stats_df.family, stats_df.max_pers_rad,
             title = "Radial: max persistence", legend = false, ylabel = "persistence", xrotation = 45)
plot(p1, p2, p3, p4, layout = (2, 2), size = (1000, 700))

4 Classification

We build the feature matrix from the extended summary statistics of both filtrations:

In [19]:
labels = families

X_features = hcat(
    stats_rips,        # 19 features: Rips H1
    stats_radial,      # 19 features: Radial H1
    stats_radial_h0,   # 19 features: Radial H0
) |> sanitize_feature_matrix

feature_blocks = ["Rips_H1", "Radial_H1", "Radial_H0"]
feature_names = ["$(block)__$(stat)" for block in feature_blocks for stat in stat_names_ext]

println("Feature matrix: $(size(X_features))")
println("  $(size(X_features, 2)) features × $(size(X_features, 1)) samples")
println("  Feature-to-sample ratio: $(round(size(X_features, 2) / size(X_features, 1), digits=2))")
Feature matrix: (70, 57)
  57 features × 70 samples
  Feature-to-sample ratio: 0.81
Leave-one-out cross-validation (LOOCV)

With only 72 samples, we use leave-one-out cross-validation: for each sample, the classifier is trained on all other samples and tested on the held-out one. LOOCV has low bias (nearly the entire dataset is used for training) and is the standard validation strategy for small datasets.

4.1 Decision tree

We use a single decision tree as our most interpretable classifier. The tree structure itself provides readable classification rules:

In [20]:
function loocv_decision_tree(X::Matrix, y::Vector{String};
                             max_depth::Int = 6,
                             min_samples_leaf::Int = 2,
                             min_samples_split::Int = 2,
                             rng_seed::Int = 20260223)
    Xclean = sanitize_feature_matrix(X)
    n = size(Xclean, 1)
    predictions = Vector{String}(undef, n)

    for i in 1:n
        train_idx = setdiff(1:n, i)
        X_train = Xclean[train_idx, :]
        y_train = y[train_idx]

        tree = DecisionTree.build_tree(
            y_train,
            X_train,
            size(X_train, 2),
            max_depth,
            min_samples_leaf,
            min_samples_split,
            0.0;
            loss = DecisionTree.util.gini,
            rng = MersenneTwister(rng_seed + i),
            impurity_importance = true
        )

        predictions[i] = DecisionTree.apply_tree(tree, Xclean[i, :])
    end

    (accuracy = mean(predictions .== y), predictions = predictions)
end

tree_results = DataFrame(
    max_depth = Int[],
    min_samples_leaf = Int[],
    n_correct = Int[],
    accuracy = Float64[],
    balanced_accuracy = Float64[],
    macro_f1 = Float64[],
)

for max_depth in [3, 4, 5, 6, 8]
    for min_leaf in [1, 2, 3]
        r = loocv_decision_tree(X_features, labels;
                                max_depth = max_depth, min_samples_leaf = min_leaf, min_samples_split = 2)
        m = classification_metrics(labels, r.predictions)
        push!(tree_results, (
            max_depth, min_leaf,
            sum(r.predictions .== labels),
            r.accuracy, m.balanced_accuracy, m.macro_f1
        ))
    end
end

sort!(tree_results, :accuracy, rev = true)
first(tree_results, 10)
10×6 DataFrame
Row max_depth min_samples_leaf n_correct accuracy balanced_accuracy macro_f1
Int64 Int64 Int64 Float64 Float64 Float64
1 5 2 37 0.528571 0.54305 0.526573
2 5 3 36 0.514286 0.532949 0.512004
3 6 2 36 0.514286 0.529161 0.5107
4 6 3 36 0.514286 0.532949 0.512004
5 8 2 36 0.514286 0.529161 0.5107
6 8 3 36 0.514286 0.532949 0.512004
7 5 1 35 0.5 0.515272 0.500981
8 6 1 34 0.485714 0.501383 0.484701
9 8 1 34 0.485714 0.501383 0.485786
10 4 2 30 0.428571 0.456229 0.406526
In [21]:
best_tree = tree_results[1, :]

tree_model = DecisionTree.build_tree(
    labels, X_features, size(X_features, 2),
    best_tree.max_depth, best_tree.min_samples_leaf, 2, 0.0;
    loss = DecisionTree.util.gini,
    rng = MersenneTwister(20260223),
    impurity_importance = true
)

tree_importance = DecisionTree.impurity_importance(tree_model; normalize = true)

tree_importance_df = DataFrame(
    feature = feature_names,
    importance = tree_importance
)
sort!(tree_importance_df, :importance, rev = true)

println("Best Decision Tree LOOCV: $(best_tree.n_correct)/$(length(labels)) ($(round(best_tree.accuracy * 100, digits = 1))%)")
println("Balanced accuracy: $(round(best_tree.balanced_accuracy * 100, digits = 1))%")

first(filter(:importance => >(0.0), tree_importance_df), 15)
Best Decision Tree LOOCV: 37/70 (52.9%)
Balanced accuracy: 54.3%
12×2 DataFrame
Row feature importance
String Float64
1 Rips_H1__entropy 0.235118
2 Rips_H1__q10 0.136255
3 Radial_H0__entropy 0.122901
4 Rips_H1__median_birth 0.105657
5 Radial_H0__total_pers 0.104808
6 Radial_H0__pers_range 0.0826186
7 Rips_H1__skewness 0.0815281
8 Radial_H0__q10 0.0752096
9 Rips_H1__max_pers 0.0152865
10 Rips_H1__std_death 0.0137579
11 Radial_H0__count 0.0137579
12 Radial_H0__median 0.0131027
In [22]:
topk = min(12, nrow(filter(:importance => >(0.0), tree_importance_df)))
top_tree_imp = first(filter(:importance => >(0.0), tree_importance_df), topk)

bar(
    top_tree_imp.feature,
    top_tree_imp.importance,
    xlabel = "Feature",
    ylabel = "Normalized impurity importance",
    title = "Decision tree feature importance (top $(topk))",
    legend = false,
    xrotation = 45,
    size = (1100, 550),
)
In [23]:
# Print the tree structure for interpretability
println("Decision tree structure:")
print_tree(tree_model, feature_names = feature_names)
Decision tree structure:
Feature 10: "Rips_H1__entropy" < 2.347 ?
├─ Feature 48: "Radial_H0__entropy" < 4.424 ?
    ├─ Feature 41: "Radial_H0__total_pers" < 25.63 ?
        ├─ Feature 17: "Rips_H1__std_death" < 13.62 ?
            ├─ Simuliidae : 6/6
            └─ Sciaridae : 1/2
        └─ Feature 14: "Rips_H1__median_birth" < 7.999 ?
            ├─ Feature 12: "Rips_H1__skewness" < -0.04769 ?
                ├─ Sciaridae : 5/5
                └─ Bibionidae : 4/4
            └─ Feature 39: "Radial_H0__count" < 156.0 ?
                ├─ Chironomidae : 6/6
                └─ Chironomidae : 1/2
    └─ Feature 14: "Rips_H1__median_birth" < 7.713 ?
        ├─ Ceratopogonidae : 1/2
        └─ Ceratopogonidae : 7/7
└─ Feature 5: "Rips_H1__q10" < 7.187 ?
    ├─ Feature 57: "Radial_H0__pers_range" < 1.804 ?
        ├─ Feature 45: "Radial_H0__median" < 0.001287 ?
            ├─ Asilidae : 1/2
            └─ Asilidae : 5/5
        └─ Feature 43: "Radial_H0__q10" < 0.0003789 ?
            ├─ Feature 10: "Rips_H1__entropy" < 2.463 ?
                ├─ Asilidae : 1/2
                └─ Tipulidae : 10/10
            └─ Rhagionidae : 3/3
    └─ Feature 10: "Rips_H1__entropy" < 2.593 ?
        ├─ Rhagionidae : 1/2
        └─ Feature 2: "Rips_H1__max_pers" < 39.98 ?
            ├─ Tabanidae : 10/10
            └─ Asilidae : 1/2

4.2 LDA (Linear Discriminant Analysis)

Linear Discriminant Analysis (LDA)

LDA finds a linear projection of the feature space that maximizes the ratio of between-class variance to within-class variance. The projected data is then classified with a simple 1-NN rule. LDA is a classical method that works well when classes are approximately Gaussian and the number of features is not too large relative to the number of samples.

In [24]:
r_lda = loocv_lda(X_features, labels)
m_lda = classification_metrics(labels, r_lda.predictions)
println("LDA LOOCV: $(sum(r_lda.predictions .== labels))/$(length(labels)) ($(round(r_lda.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(m_lda.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(m_lda.macro_f1 * 100, digits=1))%")
LDA LOOCV: 46/70 (65.7%)
Balanced accuracy: 65.4%
Macro-F1: 64.3%

4.3 Balanced Random Forest

Random Forest

A Random Forest is an ensemble of decision trees, each trained on a bootstrap sample of the data using a random subset of features. The final prediction is the majority vote across all trees. Balanced Random Forests oversample minority classes (or weight them inversely to their frequency) so that rare families are not drowned out by common ones — important here because Tipulidae has 12 samples while some families have only 2–3. Random Forests are robust to overfitting, handle high-dimensional features well, and provide built-in feature importance estimates.

In [25]:
rf_grid = []
for n_trees in [100, 200, 500, 1000]
    for max_depth in [-1, 8, 12]
        for min_leaf in [1, 2]
            r = loocv_random_forest_balanced(X_features, labels;
                    n_trees = n_trees, rng_seed = 20260223)
            m = classification_metrics(labels, r.predictions)
            push!(rf_grid, (
                n_trees = n_trees, max_depth = max_depth, min_leaf = min_leaf,
                n_correct = sum(r.predictions .== labels),
                accuracy = r.accuracy,
                balanced_accuracy = m.balanced_accuracy,
                macro_f1 = m.macro_f1,
            ))
        end
    end
end

rf_grid_df = DataFrame(rf_grid)
sort!(rf_grid_df, :accuracy, rev = true)
first(rf_grid_df, 8)
8×7 DataFrame
Row n_trees max_depth min_leaf n_correct accuracy balanced_accuracy macro_f1
Int64 Int64 Int64 Int64 Float64 Float64 Float64
1 200 -1 1 52 0.742857 0.739899 0.726137
2 200 -1 2 52 0.742857 0.739899 0.726137
3 200 8 1 52 0.742857 0.739899 0.726137
4 200 8 2 52 0.742857 0.739899 0.726137
5 200 12 1 52 0.742857 0.739899 0.726137
6 200 12 2 52 0.742857 0.739899 0.726137
7 1000 -1 1 52 0.742857 0.739899 0.722776
8 1000 -1 2 52 0.742857 0.739899 0.722776
In [26]:
best_rf_row = rf_grid_df[1, :]
r_rf = loocv_random_forest_balanced(X_features, labels;
            n_trees = best_rf_row.n_trees, rng_seed = 20260223)
m_rf = classification_metrics(labels, r_rf.predictions)
println("Best Balanced RF LOOCV: $(sum(r_rf.predictions .== labels))/$(length(labels)) ($(round(r_rf.accuracy * 100, digits=1))%)")
println("  n_trees=$(best_rf_row.n_trees)")
println("Balanced accuracy: $(round(m_rf.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(m_rf.macro_f1 * 100, digits=1))%")
Best Balanced RF LOOCV: 52/70 (74.3%)
  n_trees=200
Balanced accuracy: 74.0%
Macro-F1: 72.6%

4.4 SVM

Support Vector Machine (SVM)

An SVM finds the hyperplane that maximizes the margin between classes. The RBF (Radial Basis Function) kernel maps data into a high-dimensional space where linear separation becomes possible, controlled by a regularization parameter \(C\) (penalty for misclassification): small \(C\) allows wider margins with more misclassifications, large \(C\) enforces tight boundaries. The linear kernel finds a separating hyperplane directly in the original feature space and is less prone to overfitting when \(p \gg n\).

In [27]:
svm_results = []
for kernel in [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear]
    for cost in [0.01, 0.1, 1.0, 10.0, 100.0]
        kernel_name = kernel == LIBSVM.Kernel.RadialBasis ? "RBF" : "Linear"
        r = loocv_svm(X_features, labels; kernel = kernel, cost = cost)
        m = classification_metrics(labels, r.predictions)
        push!(svm_results, (
            method = "SVM ($kernel_name, C=$cost)",
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy,
            balanced_accuracy = m.balanced_accuracy,
            macro_f1 = m.macro_f1,
        ))
    end
end

svm_df = DataFrame(svm_results)
sort!(svm_df, :accuracy, rev = true)
first(svm_df, 8)
8×6 DataFrame
Row method n_correct n_total accuracy balanced_accuracy macro_f1
String Int64 Int64 Float64 Float64 Float64
1 SVM (Linear, C=0.1) 48 70 0.685714 0.645623 0.640102
2 SVM (Linear, C=1.0) 46 70 0.657143 0.626263 0.627028
3 SVM (Linear, C=10.0) 46 70 0.657143 0.626263 0.627028
4 SVM (Linear, C=100.0) 46 70 0.657143 0.626263 0.627028
5 SVM (RBF, C=10.0) 44 70 0.628571 0.598485 0.5844
6 SVM (RBF, C=100.0) 44 70 0.628571 0.598485 0.5844
7 SVM (RBF, C=1.0) 41 70 0.585714 0.526094 0.501635
8 SVM (Linear, C=0.01) 39 70 0.557143 0.489899 0.469207
In [28]:
best_svm_row = svm_df[1, :]
println("Best SVM: $(best_svm_row.method)")
println("  $(best_svm_row.n_correct)/$(best_svm_row.n_total) ($(round(best_svm_row.accuracy * 100, digits=1))%)")
println("  Balanced accuracy: $(round(best_svm_row.balanced_accuracy * 100, digits=1))%")
println("  Macro-F1: $(round(best_svm_row.macro_f1 * 100, digits=1))%")
Best SVM: SVM (Linear, C=0.1)
  48/70 (68.6%)
  Balanced accuracy: 64.6%
  Macro-F1: 64.0%

4.5 k-NN on Rips Wasserstein distance

Wasserstein distance between persistence diagrams

The Wasserstein distance \(W_q\) between two persistence diagrams is the cost of the optimal matching between their points (including matching points to the diagonal, representing trivial features). With \(q=1\) it equals the Earth Mover’s Distance; with \(q=2\) it penalizes large mismatches more heavily.

k-Nearest Neighbors (k-NN) classifies a query by majority vote among its \(k\) nearest neighbors in the distance matrix. With \(k=1\), this is the simplest possible classifier — completely hyperparameter-free — and serves as a useful baseline.

As a complementary approach, we compute pairwise Wasserstein distances between the Rips H1 persistence diagrams and classify with k-NN. Unlike the feature-based classifiers above, this operates directly on the persistence diagrams without extracting summary statistics, and is therefore less susceptible to information loss during featurization.

In [29]:
D_wass1_rips = wasserstein_distance_matrix(pds_rips, q = 1)
D_wass2_rips = wasserstein_distance_matrix(pds_rips, q = 2)

knn_wass_results = []
for (name, D) in [("Wass-1", D_wass1_rips), ("Wass-2", D_wass2_rips)]
    for k in [1, 3, 5]
        r = loocv_knn(D, labels; k = k)
        m = classification_metrics(labels, r.predictions)
        push!(knn_wass_results, (
            method = "$(k)-NN Rips $name",
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy,
            balanced_accuracy = m.balanced_accuracy,
            macro_f1 = m.macro_f1,
        ))
    end
end

knn_wass_df = DataFrame(knn_wass_results)
sort!(knn_wass_df, :accuracy, rev = true)
knn_wass_df
6×6 DataFrame
Row method n_correct n_total accuracy balanced_accuracy macro_f1
String Int64 Int64 Float64 Float64 Float64
1 1-NN Rips Wass-1 48 70 0.685714 0.656566 0.647072
2 5-NN Rips Wass-2 47 70 0.671429 0.638889 0.620491
3 3-NN Rips Wass-1 46 70 0.657143 0.610269 0.589769
4 5-NN Rips Wass-1 45 70 0.642857 0.583333 0.537277
5 1-NN Rips Wass-2 44 70 0.628571 0.587121 0.579056
6 3-NN Rips Wass-2 44 70 0.628571 0.58165 0.553934
In [30]:
# Visualise one of the Wasserstein distance matrices
p_wass = heatmap(D_wass1_rips,
    xticks = (1:length(individuals), individuals),
    yticks = (1:length(individuals), individuals),
    title = "Rips Wasserstein-1 distance",
    color = :viridis, xrotation = 90, size = (750, 700))
display(p_wass)

4.6 Summary of all classifiers

In [31]:
all_results = []

# Decision Tree
push!(all_results, (method = "Decision Tree (d=$(best_tree.max_depth))",
    n_correct = best_tree.n_correct, n_total = length(labels),
    accuracy = best_tree.accuracy, balanced_accuracy = best_tree.balanced_accuracy))

# LDA
push!(all_results, (method = "LDA",
    n_correct = sum(r_lda.predictions .== labels), n_total = length(labels),
    accuracy = r_lda.accuracy, balanced_accuracy = m_lda.balanced_accuracy))

# RF
push!(all_results, (method = "Balanced RF (T=$(best_rf_row.n_trees))",
    n_correct = sum(r_rf.predictions .== labels), n_total = length(labels),
    accuracy = r_rf.accuracy, balanced_accuracy = m_rf.balanced_accuracy))

# Best SVM
best_svm = svm_df[1, :]
push!(all_results, (method = best_svm.method,
    n_correct = best_svm.n_correct, n_total = best_svm.n_total,
    accuracy = best_svm.accuracy, balanced_accuracy = best_svm.balanced_accuracy))

# k-NN on Wasserstein (best per distance)
for name in ["Wass-1", "Wass-2"]
    sub = filter(:method => m -> contains(m, name), knn_wass_df)
    best_knn = sub[1, :]  # already sorted by accuracy
    push!(all_results, (method = best_knn.method,
        n_correct = best_knn.n_correct, n_total = best_knn.n_total,
        accuracy = best_knn.accuracy, balanced_accuracy = best_knn.balanced_accuracy))
end

results_df = DataFrame(all_results)
sort!(results_df, :accuracy, rev = true)
results_df
6×5 DataFrame
Row method n_correct n_total accuracy balanced_accuracy
String Int64 Int64 Float64 Float64
1 Balanced RF (T=200) 52 70 0.742857 0.739899
2 SVM (Linear, C=0.1) 48 70 0.685714 0.645623
3 1-NN Rips Wass-1 48 70 0.685714 0.656566
4 5-NN Rips Wass-2 47 70 0.671429 0.638889
5 LDA 46 70 0.657143 0.65374
6 Decision Tree (d=5) 37 70 0.528571 0.54305

5 Which features drive the classification?

In [32]:
# ── Feature importance from full-data RF ────────────────────────────────────
# Build individual trees with impurity_importance tracking enabled,
# then aggregate importances across the forest.
rng_imp = MersenneTwister(20260223)
n_feat_imp = max(1, round(Int, sqrt(size(X_features, 2))))
n_trees_imp = 1000

tree_importances = zeros(size(X_features, 2))
for t in 1:n_trees_imp
    # Bootstrap sample (with replacement)
    n = size(X_features, 1)
    idx = rand(rng_imp, 1:n, n)
    X_boot = X_features[idx, :]
    y_boot = labels[idx]

    tree = DecisionTree.build_tree(
        y_boot, X_boot, n_feat_imp, -1, 1, 2, 0.0;
        loss = DecisionTree.util.gini,
        rng = rng_imp,
        impurity_importance = true
    )
    tree_importances .+= DecisionTree.impurity_importance(tree; normalize = false)
end

# Normalize
rf_imp = tree_importances ./ maximum(tree_importances)

imp_df = DataFrame(
    feature = feature_names,
    importance = rf_imp
)
sort!(imp_df, :importance, rev = true)
top_imp = first(filter(:importance => >(0.0), imp_df), 20)
top_imp
20×2 DataFrame
Row feature importance
String Float64
1 Rips_H1__entropy 1.0
2 Rips_H1__total_pers 0.949461
3 Radial_H0__count 0.926215
4 Radial_H0__entropy 0.907955
5 Rips_H1__max_pers 0.835533
6 Radial_H0__total_pers 0.798282
7 Rips_H1__median_birth 0.795691
8 Rips_H1__total_pers2 0.733267
9 Radial_H0__std_birth 0.726678
10 Radial_H0__total_pers2 0.670079
11 Rips_H1__q90 0.621246
12 Rips_H1__q75 0.619905
13 Radial_H0__q25 0.610705
14 Rips_H1__count 0.602051
15 Radial_H0__q10 0.598654
16 Rips_H1__std_death 0.596886
17 Rips_H1__std_pers 0.574682
18 Rips_H1__pers_range 0.512191
19 Rips_H1__q10 0.474335
20 Radial_H0__median_birth 0.472743
In [33]:
topk_imp = min(15, nrow(top_imp))

bar(
    top_imp.feature[1:topk_imp],
    top_imp.importance[1:topk_imp],
    xlabel = "Feature",
    ylabel = "Normalized impurity importance",
    title = "RF feature importance (top $(topk_imp))",
    legend = false,
    xrotation = 45,
    size = (1100, 550),
)

6 Feature ablation: Rips vs Radial

In [34]:
# ── Test each filtration alone and combined ─────────────────────────────────
ablation_sets = [
    ("Rips H1 only",              hcat(stats_rips)),
    ("Radial H1 only",            hcat(stats_radial)),
    ("Radial H0 only",            hcat(stats_radial_h0)),
    ("Radial H0 + H1",            hcat(stats_radial, stats_radial_h0)),
    ("Rips + Radial H0",          hcat(stats_rips, stats_radial_h0)),
    ("Rips + Radial H1",          hcat(stats_rips, stats_radial)),
    ("Rips + Radial H0+H1 (all)", X_features),
]

ablation_results = []
for (name, X_abl) in ablation_sets
    X_abl_clean = sanitize_feature_matrix(X_abl)

    r_lda_abl = loocv_lda(X_abl_clean, labels)
    r_rf_abl = loocv_random_forest_balanced(X_abl_clean, labels; n_trees=500, rng_seed=20260223)

    push!(ablation_results, (
        filtrations = name,
        n_features = size(X_abl_clean, 2),
        lda_correct = sum(r_lda_abl.predictions .== labels),
        lda_accuracy = round(r_lda_abl.accuracy * 100, digits=1),
        rf_correct = sum(r_rf_abl.predictions .== labels),
        rf_accuracy = round(r_rf_abl.accuracy * 100, digits=1),
    ))
end

ablation_df = DataFrame(ablation_results)
sort!(ablation_df, :lda_accuracy, rev = true)
ablation_df
7×6 DataFrame
Row filtrations n_features lda_correct lda_accuracy rf_correct rf_accuracy
String Int64 Int64 Float64 Int64 Float64
1 Rips + Radial H0 38 49 70.0 51 72.9
2 Rips + Radial H0+H1 (all) 57 46 65.7 51 72.9
3 Radial H0 only 19 43 61.4 40 57.1
4 Rips + Radial H1 38 41 58.6 47 67.1
5 Radial H0 + H1 38 40 57.1 41 58.6
6 Rips H1 only 19 37 52.9 45 64.3
7 Radial H1 only 19 10 14.3 16 22.9

7 Honest evaluation (Nested LOOCV)

Nested cross-validation

Standard LOOCV can give optimistically biased estimates when hyperparameters are tuned on the same data. Nested LOOCV adds an inner cross-validation loop: for each held-out test sample, the best hyperparameters are selected using only the training fold. This provides an unbiased estimate of generalization performance.

In [35]:
nested_rf = nested_loocv_random_forest(
    X_features, labels;
    n_trees_grid = [200, 500],
    max_depth_grid = [-1],
    min_samples_leaf_grid = [1, 2],
    inner_folds = 4,
    balanced = true,
    rng_seed = 20260223
)
n_correct_nested = sum(nested_rf.predictions .== labels)

println("=== Nested LOOCV: Balanced RF ===")
println("Features: $(size(X_features, 2)) (Rips + Radial stats)")
println("Accuracy: $(n_correct_nested)/$(length(labels)) ($(round(nested_rf.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_rf.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_rf.macro_f1 * 100, digits=1))%")

ci_nested = wilson_ci(n_correct_nested, length(labels))
println("95% Wilson CI: [$(round(ci_nested.lower * 100, digits=1))%, $(round(ci_nested.upper * 100, digits=1))%]")
=== Nested LOOCV: Balanced RF ===
Features: 57 (Rips + Radial stats)
Accuracy: 44/70 (62.9%)
Balanced accuracy: 62.7%
Macro-F1: 61.1%
95% Wilson CI: [51.1%, 73.2%]

7.1 Confusion matrix

In [36]:
cm_nested = confusion_matrix(labels, nested_rf.predictions)
classes_nested = cm_nested.classes


println("Per-class accuracy (Nested LOOCV):")
for (i, cls) in enumerate(classes_nested)
    correct = cm_nested.matrix[i, i]
    total = sum(cm_nested.matrix[i, :])
    println("  $(cls): $(correct)/$(total) ($(round(correct / total * 100, digits=1))%)")
end
Per-class accuracy (Nested LOOCV):
  Asilidae: 5/8 (62.5%)
  Bibionidae: 5/6 (83.3%)
  Ceratopogonidae: 7/8 (87.5%)
  Chironomidae: 2/8 (25.0%)
  Rhagionidae: 1/4 (25.0%)
  Sciaridae: 4/6 (66.7%)
  Simuliidae: 7/7 (100.0%)
  Tabanidae: 8/11 (72.7%)
  Tipulidae: 5/12 (41.7%)
In [37]:
heatmap(cm_nested.matrix,
        xticks = (1:length(classes_nested), classes_nested),
        yticks = (1:length(classes_nested), classes_nested),
        xlabel = "Predicted", ylabel = "True",
        title = "Confusion Matrix (Nested LOOCV — Balanced RF)",
        color = :Blues,
        clims = (0, maximum(cm_nested.matrix)),
        xrotation = 45, size = (700, 600))

7.2 Confidence interval for best classifier

In [38]:
best_row = results_df[1, :]
println("=== Best Method ===")
println("$(best_row.method): $(best_row.n_correct)/$(best_row.n_total) ($(round(best_row.accuracy * 100, digits=1))%)")
ci = wilson_ci(best_row.n_correct, best_row.n_total)
println("95% Wilson CI: [$(round(ci.lower * 100, digits=1))%, $(round(ci.upper * 100, digits=1))%]")
=== Best Method ===
Balanced RF (T=200): 52/70 (74.3%)
95% Wilson CI: [63.0%, 83.1%]

8 Discussion

We applied two TDA filtration strategies — Vietoris-Rips and radial — to classify Diptera families from wing venation images, extracting 19 extended summary statistics per persistence diagram.

8.1 Key findings

  1. Two filtrations capture complementary information: The Vietoris-Rips filtration on point-cloud samples captures the global loop structure of the wing venation (number and prominence of wing cells). The radial filtration captures the center-to-periphery organization: how veins and cells are arranged spatially from the wing base outward.

  2. Extended summary statistics are sufficient: The 19-feature extended statistics (count, max/total persistence, quantiles, entropy, skewness, kurtosis, median birth/death, etc.) capture the essential information from each persistence diagram. With 3 diagrams × 19 features = 57 total features for 72 samples, the feature-to-sample ratio stays reasonable (~0.8:1), reducing overfitting risk.

  3. Feature ablation reveals which filtrations matter: The ablation study shows whether Rips alone, radial alone, or the combination gives the best performance. This provides evidence about whether global topology (Rips) or spatial organization (radial) is more discriminative.

  4. Why other filtrations were dropped:

    • Directional (height) filtrations: 8 sweep directions × H0+H1 produced a large feature set dominated by noise. On noisy binarized images with isolated pixels and incomplete segmentation, each sweep direction generates spurious topological features.
    • EDT (Euclidean Distance Transform): On binarized images, the EDT is trivially related to the binary structure, providing little additional information beyond what Rips already captures.
    • Cubical (grayscale sublevel-set): After binarization, the grayscale information is lost, so cubical persistence reduces to computing persistence on a binary image — equivalent to connected-component analysis.
  5. Nested LOOCV provides honest evaluation: Standard LOOCV can be optimistic when hyperparameters are tuned on the same data. Nested LOOCV (with 4-fold inner CV for hyperparameter selection) gives unbiased accuracy estimates.

  6. Statistical rigor: We report LOOCV accuracy with Wilson confidence intervals, and nested LOOCV for unbiased evaluation.

8.2 Limitations

  • Class imbalance: Tipulidae has 12 samples while some families have only 2–3, which affects classifier performance.
  • Image quality and preprocessing parameters (blur, threshold) influence topological features.
  • With only 72 samples, confidence intervals remain wide regardless of method.
  • Wings are manually segmented and binarized; automated segmentation could introduce different error patterns.

8.3 Future work

  • Extend dataset with more specimens per family, especially underrepresented ones
  • Improve imaging/segmentation quality to reduce noise
  • Apply extended persistence or zigzag persistence for richer invariants
  • Investigate which specific topological features (e.g., how many loops, persistence of largest features) correspond to known vein characters in Diptera taxonomy
  • Try the analysis on non-binarized (grayscale) images, where EDT and cubical filtrations would be more informative